{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n    @Author: King\\n    @Date: 2019.06.25\\n    @Purpose: Graph Convolutional Network\\n    @Introduction:   This is a gentle introduction of using DGL to implement \\n                    Graph Convolutional Networks (Kipf & Welling et al., \\n                    Semi-Supervised Classification with Graph Convolutional Networks). \\n                    We build upon the earlier tutorial on DGLGraph and demonstrate how DGL \\n                    combines graph with deep neural network and learn structural representations.\\n    @Datasets: \\n    @Link : \\n    @Reference : https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html\\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "    @Author: King\n",
    "    @Date: 2019.06.25\n",
    "    @Purpose: Graph Convolutional Network\n",
    "    @Introduction:   This is a gentle introduction of using DGL to implement \n",
    "                    Graph Convolutional Networks (Kipf & Welling et al., \n",
    "                    Semi-Supervised Classification with Graph Convolutional Networks). \n",
    "                    We build upon the earlier tutorial on DGLGraph and demonstrate how DGL \n",
    "                    combines graph with deep neural network and learn structural representations.\n",
    "    @Datasets: \n",
    "    @Link : \n",
    "    @Reference : https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Relational Graph Convolutional Network\n",
    "\n",
    "A knowledge graph is made up of a collection of triples of the form (subject, relation, object). Edges thus encode important information and have their own embeddings to be learned. Furthermore, there may be multiple edges between any given pair of entities.\n",
    "\n",
    "Relational-GCN (R-GCN), a model from the paper Modeling Relational Data with Graph Convolutional Networks, is one effort to generalize GCN to handle different relations between entities in a knowledge base. This tutorial shows how to implement R-GCN with DGL.\n",
    "\n",
    "\n",
    "## R-GCN: a brief introduction\n",
    "\n",
    "In statistical relational learning (SRL), there are two fundamental tasks:\n",
    "\n",
    "- Entity classification, i.e., assigning types and categorical properties to entities.\n",
    "- Link prediction, i.e., recovering missing triples.\n",
    "\n",
    "In both cases, missing information is expected to be recovered from the neighborhood structure of the graph. Here is the example from the R-GCN paper:\n",
    "\n",
    "“Knowing that Mikhail Baryshnikov was educated at the Vaganova Academy implies both that Mikhail Baryshnikov should have the label person, and that the triple (Mikhail Baryshnikov, lived in, Russia) must belong to the knowledge graph.”\n",
    "\n",
    "R-GCN solves these two problems using a common graph convolutional network, extended with multi-edge encoding to compute embeddings of the entities, but with different downstream processing:\n",
    "\n",
    "- Entity classification is done by attaching a softmax classifier to the final embedding of an entity (node). Training uses the standard cross-entropy loss.\n",
    "\n",
    "- Link prediction is done by reconstructing an edge with an autoencoder architecture, using a parameterized score function. Training uses negative sampling.\n",
    "\n",
    "This tutorial focuses on the first task to show how to generate entity representations. Complete code for both tasks can be found in DGL’s GitHub repository.\n",
    "\n",
    "## Key ideas of R-GCN\n",
    "\n",
    "![](img/RGCN1.png)\n",
    "![](img/RGCN2.png)"
   ]
  },
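  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layer code later in this notebook refers to equation (3). As a reference sketch, the two key equations are written out here in the notation of the R-GCN paper; the numbering is assumed to follow the DGL tutorial on which this notebook is based. Here $h_i^{(l)}$ is the hidden state of node $i$ at layer $l$, $N_i^r$ is the set of neighbors of node $i$ under relation $r$, $c_{i,r}$ is a normalization constant, and $\\sigma$ is an activation function.\n",
    "\n",
    "$$h_i^{(l+1)} = \\sigma\\left(W_0^{(l)} h_i^{(l)} + \\sum_{r \\in R} \\sum_{j \\in N_i^r} \\frac{1}{c_{i,r}} W_r^{(l)} h_j^{(l)}\\right) \\tag{2}$$\n",
    "\n",
    "$$W_r^{(l)} = \\sum_{b=1}^{B} a_{rb}^{(l)} V_b^{(l)} \\tag{3}$$\n",
    "\n",
    "Equation (3) is the basis decomposition: each relation-specific weight $W_r^{(l)}$ is a linear combination of $B$ shared basis matrices $V_b^{(l)}$ with learned coefficients $a_{rb}^{(l)}$, which reduces the number of parameters when there are many relations."
   ]
  },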
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implement R-GCN in DGL\n",
    "\n",
    "An R-GCN model is composed of several R-GCN layers. The first R-GCN layer also serves as the input layer: it takes in the features (e.g. description texts) associated with each node entity and projects them to the hidden space. In this tutorial, we only use the entity id as the entity feature.\n",
    "\n",
    "### R-GCN Layers\n",
    "\n",
    "For each node, an R-GCN layer performs the following steps:\n",
    "\n",
    "- Compute the outgoing message using the node representation and the weight matrix associated with the edge type (message function).\n",
    "- Aggregate the incoming messages and generate new node representations (reduce and apply functions).\n",
    "\n",
    "The following is the definition of an R-GCN hidden layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from dgl import DGLGraph\n",
    "import dgl.function as fn\n",
    "from functools import partial\n",
    "\n",
    "class RGCNLayer(nn.Module):\n",
    "    def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,\n",
    "                 activation=None, is_input_layer=False):\n",
    "        super(RGCNLayer, self).__init__()\n",
    "        self.in_feat = in_feat\n",
    "        self.out_feat = out_feat\n",
    "        self.num_rels = num_rels\n",
    "        self.num_bases = num_bases\n",
    "        self.activation = activation\n",
    "        self.is_input_layer = is_input_layer\n",
    "\n",
    "        # sanity check: fall back to one basis per relation\n",
    "        if self.num_bases <= 0 or self.num_bases > self.num_rels:\n",
    "            self.num_bases = self.num_rels\n",
    "\n",
    "        # weight bases in equation (3)\n",
    "        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,\n",
    "                                                self.out_feat))\n",
    "        if self.num_bases < self.num_rels:\n",
    "            # linear combination coefficients in equation (3)\n",
    "            self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))\n",
    "\n",
    "        # add bias; keep None when unused so later checks stay unambiguous\n",
    "        if bias:\n",
    "            self.bias = nn.Parameter(torch.Tensor(out_feat))\n",
    "        else:\n",
    "            self.bias = None\n",
    "\n",
    "        # init trainable parameters\n",
    "        nn.init.xavier_uniform_(self.weight,\n",
    "                                gain=nn.init.calculate_gain('relu'))\n",
    "        if self.num_bases < self.num_rels:\n",
    "            nn.init.xavier_uniform_(self.w_comp,\n",
    "                                    gain=nn.init.calculate_gain('relu'))\n",
    "        if self.bias is not None:\n",
    "            # Xavier init needs at least 2 dimensions, so zero the 1-D bias\n",
    "            nn.init.constant_(self.bias, 0.)\n",
    "\n",
    "    def forward(self, g):\n",
    "        if self.num_bases < self.num_rels:\n",
    "            # generate all weights from bases (equation (3)):\n",
    "            # W_r = sum_b w_comp[r, b] * weight[b]\n",
    "            weight = torch.matmul(self.w_comp,\n",
    "                                  self.weight.view(self.num_bases, -1))\n",
    "            weight = weight.view(self.num_rels, self.in_feat, self.out_feat)\n",
    "        else:\n",
    "            weight = self.weight\n",
    "\n",
    "        if self.is_input_layer:\n",
    "            def message_func(edges):\n",
    "                # for input layer, matrix multiply can be converted to be\n",
    "                # an embedding lookup using source node id\n",
    "                embed = weight.view(-1, self.out_feat)\n",
    "                index = edges.data['rel_type'] * self.in_feat + edges.src['id']\n",
    "                return {'msg': embed[index] * edges.data['norm']}\n",
    "        else:\n",
    "            def message_func(edges):\n",
    "                # pick the weight matrix of each edge's relation type\n",
    "                w = weight[edges.data['rel_type']]\n",
    "                # squeeze only the singleton dim added by unsqueeze(1)\n",
    "                msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze(1)\n",
    "                msg = msg * edges.data['norm']\n",
    "                return {'msg': msg}\n",
    "\n",
    "        def apply_func(nodes):\n",
    "            h = nodes.data['h']\n",
    "            if self.bias is not None:\n",
    "                h = h + self.bias\n",
    "            if self.activation:\n",
    "                h = self.activation(h)\n",
    "            return {'h': h}\n",
    "\n",
    "        g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)"
   ]
  },
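  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input layer's message function replaces a matrix product with a table lookup. A brief sketch of why this works, assuming each node's input feature is the one-hot vector $e_j \\in \\mathbb{R}^{N}$ of its id $j$ (so `in_feat` equals the number of nodes $N$):\n",
    "\n",
    "$$e_j^\\top W_r = (W_r)_{j,:}$$\n",
    "\n",
    "That is, the product selects row $j$ of $W_r$. When the weights of all $R$ relations, each of shape $N \\times d$, are stacked into a single $(R \\cdot N) \\times d$ matrix, as `weight.view(-1, self.out_feat)` does, row $j$ of $W_r$ sits at flat row index\n",
    "\n",
    "$$\\text{index} = r \\cdot N + j$$\n",
    "\n",
    "which is the expression `rel_type * in_feat + id` in the code. For instance, with $N = 5$, the message along an edge of relation $r = 2$ from node $j = 3$ reads row $2 \\cdot 5 + 3 = 13$."
   ]
  },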
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define full R-GCN model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(self, num_nodes, h_dim, out_dim, num_rels,\n",
    "                 num_bases=-1, num_hidden_layers=1):\n",
    "        super(Model, self).__init__()\n",
    "        self.num_nodes = num_nodes\n",
    "        self.h_dim = h_dim\n",
    "        self.out_dim = out_dim\n",
    "        self.num_rels = num_rels\n",
    "        self.num_bases = num_bases\n",
    "        self.num_hidden_layers = num_hidden_layers\n",
    "\n",
    "        # create rgcn layers\n",
    "        self.build_model()\n",
    "\n",
    "        # create initial features\n",
    "        self.features = self.create_features()\n",
    "\n",
    "    def build_model(self):\n",
    "        self.layers = nn.ModuleList()\n",
    "        # input to hidden\n",
    "        i2h = self.build_input_layer()\n",
    "        self.layers.append(i2h)\n",
    "        # hidden to hidden\n",
    "        for _ in range(self.num_hidden_layers):\n",
    "            h2h = self.build_hidden_layer()\n",
    "            self.layers.append(h2h)\n",
    "        # hidden to output\n",
    "        h2o = self.build_output_layer()\n",
    "        self.layers.append(h2o)\n",
    "\n",
    "    # initialize feature for each node (the node id itself)\n",
    "    def create_features(self):\n",
    "        features = torch.arange(self.num_nodes)\n",
    "        return features\n",
    "\n",
    "    def build_input_layer(self):\n",
    "        return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases,\n",
    "                         activation=F.relu, is_input_layer=True)\n",
    "\n",
    "    def build_hidden_layer(self):\n",
    "        return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,\n",
    "                         activation=F.relu)\n",
    "\n",
    "    def build_output_layer(self):\n",
    "        # return raw class scores: F.cross_entropy in the training loop\n",
    "        # applies log-softmax itself, so no softmax activation is used here\n",
    "        return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases,\n",
    "                         activation=None)\n",
    "\n",
    "    def forward(self, g):\n",
    "        if self.features is not None:\n",
    "            g.ndata['id'] = self.features\n",
    "        for layer in self.layers:\n",
    "            layer(g)\n",
    "        return g.ndata.pop('h')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Handle dataset\n",
    "\n",
    "In this tutorial, we use the AIFB dataset from the R-GCN paper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading C:\\Users\\ASUS\\.dgl\\aifb.tgz from https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/aifb.tgz...\n",
      "Extracting file to C:\\Users\\ASUS\\.dgl\\aifb\n",
      "Loading dataset aifb\n",
      "Graph loaded, frequencies counted.\n",
      "Number of nodes:  8285\n",
      "Number of relations:  91\n",
      "Number of edges:  66371\n",
      "4 classes: {'http://www.aifb.uni-karlsruhe.de/Forschungsgruppen/viewForschungsgruppeOWL/id2instance', 'http://www.aifb.uni-karlsruhe.de/Forschungsgruppen/viewForschungsgruppeOWL/id3instance', 'http://www.aifb.uni-karlsruhe.de/Forschungsgruppen/viewForschungsgruppeOWL/id4instance', 'http://www.aifb.uni-karlsruhe.de/Forschungsgruppen/viewForschungsgruppeOWL/id1instance'}\n",
      "Loading training set\n",
      "Loading test set\n",
      "Number of classes:  4\n",
      "removing nodes that are more than 3 hops away\n"
     ]
    }
   ],
   "source": [
    "# load graph data\n",
    "from dgl.contrib.data import load_data\n",
    "import numpy as np\n",
    "data = load_data(dataset='aifb')\n",
    "num_nodes = data.num_nodes\n",
    "num_rels = data.num_rels\n",
    "num_classes = data.num_classes\n",
    "labels = data.labels\n",
    "train_idx = data.train_idx\n",
    "# split training and validation set\n",
    "val_idx = train_idx[:len(train_idx) // 5]\n",
    "train_idx = train_idx[len(train_idx) // 5:]\n",
    "\n",
    "# edge type and normalization factor\n",
    "# cast to int64 so edge types can be combined with the int64 node ids\n",
    "# and used as tensor indices (NumPy integer arrays may be int32 on some platforms)\n",
    "edge_type = torch.from_numpy(data.edge_type).long()\n",
    "edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1)\n",
    "\n",
    "# cross-entropy expects int64 class labels\n",
    "labels = torch.from_numpy(labels).view(-1).long()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create graph and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# configurations\n",
    "n_hidden = 16 # number of hidden units\n",
    "n_bases = -1 # use number of relations as number of bases\n",
    "n_hidden_layers = 0 # use 1 input layer, 1 output layer, no hidden layer\n",
    "n_epochs = 25 # epochs to train\n",
    "lr = 0.01 # learning rate\n",
    "l2norm = 0 # L2 norm coefficient\n",
    "\n",
    "# create graph\n",
    "g = DGLGraph()\n",
    "g.add_nodes(num_nodes)\n",
    "g.add_edges(data.edge_src, data.edge_dst)\n",
    "g.edata.update({'rel_type': edge_type, 'norm': edge_norm})\n",
    "\n",
    "# create model\n",
    "model = Model(len(g),\n",
    "              n_hidden,\n",
    "              num_classes,\n",
    "              num_rels,\n",
    "              num_bases=n_bases,\n",
    "              num_hidden_layers=n_hidden_layers)"
   ]
  },
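  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on model size under the configuration above, assuming `num_nodes` = 8285, `num_rels` = 91 and 4 classes as printed by the data loader: with `n_bases = -1` every relation gets its own weight matrix and no bias is used, so the two layers hold\n",
    "\n",
    "$$\\underbrace{91 \\times 8285 \\times 16}_{\\text{input layer}} + \\underbrace{91 \\times 16 \\times 4}_{\\text{output layer}} = 12{,}062{,}960 + 5{,}824 = 12{,}068{,}784$$\n",
    "\n",
    "trainable parameters. Almost all of them sit in the input layer, because its input dimension equals the number of nodes. This growth with the number of relations is what setting `n_bases` below `num_rels` is meant to curb."
   ]
  },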
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start training...\n"
     ]
    }
   ],
   "source": [
    "# optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm)\n",
    "\n",
    "print(\"start training...\")\n",
    "model.train()\n",
    "for epoch in range(n_epochs):\n",
    "    optimizer.zero_grad()\n",
    "    logits = model.forward(g)\n",
    "    loss = F.cross_entropy(logits[train_idx], labels[train_idx])\n",
    "    loss.backward()\n",
    "\n",
    "    optimizer.step()\n",
    "\n",
    "    train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])\n",
    "    train_acc = train_acc.item() / len(train_idx)\n",
    "    val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])\n",
    "    val_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx])\n",
    "    val_acc = val_acc.item() / len(val_idx)\n",
    "    print(\"Epoch {:05d} | \".format(epoch) +\n",
    "          \"Train Accuracy: {:.4f} | Train Loss: {:.4f} | \".format(\n",
    "              train_acc, loss.item()) +\n",
    "          \"Validation Accuracy: {:.4f} | Validation loss: {:.4f}\".format(\n",
    "              val_acc, val_loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
